{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef2525df-6418-493c-bfed-cecfc934a467",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join\n",
    "\n",
    "import dask.dataframe as dd\n",
    "import dask.array as da\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from dask_ml.decomposition import IncrementalPCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e91c67be-8ca7-4496-94d8-b6a71299e560",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "417a896b-5667-43fc-a018-c34acf9ea27d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_count_matrix(ddf):\n",
    "    x = (\n",
    "        ddf['X']\n",
    "        .map_partitions(\n",
    "            lambda xx: pd.DataFrame(np.vstack(xx.tolist())), \n",
    "            meta={col: 'f4' for col in range(19331)}\n",
    "        )\n",
    "        .to_dask_array(lengths=[1024] * ddf.npartitions)\n",
    "    )\n",
    "\n",
    "    return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0f99d72-6183-412b-b06b-423c53f5b06a",
   "metadata": {},
   "source": [
    "# Compute PCA for visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43fb96d0-7155-4e7d-8428-2757183474a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.makedirs(join(PATH, 'pca'), exist_ok=True)\n",
    "\n",
    "\n",
    "n_comps = 50\n",
    "\n",
    "\n",
    "for split in ['test', 'val', 'train']:\n",
    "    x = get_count_matrix(dd.read_parquet(join(PATH, split), split_row_groups=True))\n",
    "    pca = IncrementalPCA(n_components=n_comps, iterated_power=3)\n",
    "    x_pca = da.compute(pca.fit_transform(x))[0]\n",
    "    with open(join(PATH, 'pca', f'x_pca_{split}_{n_comps}.npy'), 'wb') as f:\n",
    "        np.save(f, x_pca)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dea42c83-eabf-4a4a-886a-f7de38411343",
   "metadata": {},
   "source": [
    "# Compute PCA for model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fed1c4a0-5c42-4911-9e04-d55ae4bca680",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "os.makedirs(join(PATH, 'pca'), exist_ok=True)\n",
    "\n",
    "\n",
    "n_comps = 256\n",
    "\n",
    "\n",
    "x_train = get_count_matrix(dd.read_parquet(join(PATH, 'train'), split_row_groups=True))\n",
    "x_val = get_count_matrix(dd.read_parquet(join(PATH, 'val'), split_row_groups=True))\n",
    "x_test = get_count_matrix(dd.read_parquet(join(PATH, 'test'), split_row_groups=True))\n",
    "\n",
    "\n",
    "pca = IncrementalPCA(n_components=n_comps, iterated_power=3)\n",
    "x_pca_train, x_pca_val, x_pca_test = da.compute(\n",
    "    [pca.fit_transform(x_train), pca.transform(x_val), pca.transform(x_test)]\n",
    ")[0]\n",
    "\n",
    "\n",
    "with open(join(PATH, f'pca/x_pca_training_train_split_{n_comps}.npy'), 'wb') as f:\n",
    "    np.save(f, x_pca_train)\n",
    "with open(join(PATH, f'pca/x_pca_training_val_split_{n_comps}.npy'), 'wb') as f:\n",
    "    np.save(f, x_pca_val)\n",
    "with open(join(PATH, f'pca/x_pca_training_test_split_{n_comps}.npy'), 'wb') as f:\n",
    "    np.save(f, x_pca_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b997843e-7ba9-4b77-8303-d8c0ab0c5525",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
